#!/usr/bin/env python
# coding=utf-8
'''
Description  :  
Author       : Boxin Zhang, Azure-Tang
Version      : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
'''
import re
import sys
import threading

import torch
import torch.distributed as dist
from torch import nn
import itertools
import time
import enum
from transformers import (
    LogitsProcessorList,
    LogitNormalization,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    MinPLogitsWarper,
    TypicalLogitsWarper,
    EpsilonLogitsWarper,
    EtaLogitsWarper,
)

from ktransformers.util.custom_loader import ModelLoaderFactory, ModelLoader, SafeTensorLoader, translate_name_to_gguf
from ktransformers.operators import base_operator
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.util.textstream import TextStreamer
if not torch.xpu.is_available():
    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
import socket

warm_uped = False
CUR_DEVICE = None
W8A8_ENABLE = False
Q4_GGUF_LODER = None
_USE_NPU_GRAPH = False
_MAX_DECODE_PROFILE = 1
WARM_UP_SKIP_CNT = [1, 1]
_SPECULATE_STEP = 1

try:
    import torch_npu
    use_torch_npu = torch_npu.npu.is_available()
    from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size
except Exception:
    use_torch_npu = False

def get_use_npu_graph():
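    """Return whether decoding is currently running under the NPU graph runner."""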
    assert _USE_NPU_GRAPH is not None, "_USE_NPU_GRAPH has not been set"
    return _USE_NPU_GRAPH

from enum import StrEnum  # NOTE: StrEnum requires Python 3.11+

class StatKey(StrEnum):
    Embedding = "Embedding"
    GraphCapture = "GraphCapture"
    GraphReplay = "GraphReplay"
    ExpertsForward1 = "ExpertsForward1"
    ExpertsForward2 = "ExpertsForward2"
    CPUExperts = "CPUExperts"
    GraphDestroy = "GraphDestroy"
    DecodeOneTokenPost = "DecodeOneTokenPost"
    DecodeOneToken = "DecodeOneToken"
    GraphInit = "GraphInit"

class TimeStat:
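    """Collects per-stage wall-clock timings (min/max/avg/count/total) for prefill and decode."""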
    def __init__(self):
        # open_status = os.environ["KT_PERF_STAT"] if "KT_PERF_STAT" in os.environ else "0"
        # if open_status == "0":
        #     self.on = False
        # else:
        #     self.on = True
        self.on = True
        self.prefill_stats = dict()
        self.decode_stats = dict()
        for key in StatKey:
            self.prefill_stats[key] = StatItem()
            self.decode_stats[key] = StatItem()
        self.reset_all()

    def record_start_time(self):
        start_time = time.time_ns()
        return start_time

    def add_time_stat(self, key: StatKey, time_ns, is_prefill):
        if not key:
            return
        # torch.cuda.synchronize()
        cost = time.time_ns() - time_ns
        if is_prefill:
            item = self.prefill_stats[key]
        else:
            item = self.decode_stats[key]
        item.add_item(cost)

    def print_all(self):
        # rank = f"[rank:{torch.distributed.get_rank()}]"
        rank = f"[rank:0]"
        msg = f"\n{rank} Prefill Time Stat\n"
        msg += rank + " {:27}{:>15}{:>15}{:>15}{:>15}{:>15}\n".format("", "min(ms)", "max(ms)", "avg(ms)", "count", "total(ms)")
        for key, value in self.prefill_stats.items():
            msg += rank + f" {key.value:<25}:{value.get_stat()}\n"
        msg += f"\n{rank} Decode Time Stat\n"
        msg += rank + " {:27}{:>15}{:>15}{:>15}{:>15}{:>15}\n".format("", "min(ms)", "max(ms)", "avg(ms)", "count", "total(ms)")
        for key, value in self.decode_stats.items():
            msg += rank + f" {key.value:<25}:{value.get_stat()}\n"
        print(msg)

    def reset_all(self):
        for _, value in self.prefill_stats.items():
            value.reset()
        for _, value in self.decode_stats.items():
            value.reset()


class StatItem:
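    """Accumulates timing samples in nanoseconds and formats them in milliseconds."""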
    def __init__(self):
        self.min_time = float("inf")
        self.max_time = 0
        self.total_time_ns = 0
        self.count = 0

    def add_item(self, cost_time_ns):
        self.count += 1
        self.total_time_ns += cost_time_ns
        self.min_time = min(self.min_time, cost_time_ns)
        self.max_time = max(self.max_time, cost_time_ns)

    def reset(self):
        self.min_time = float("inf")
        self.max_time = 0
        self.total_time_ns = 0
        self.count = 0

    def get_stat(self):
        if self.count == 0:
            return f"{0.0:15.2f}{0.0:15.2f}{0.0:15.2f}{0:15}{0.0:15.2f}"
        min_time = self.min_time / 1000 / 1000
        max_time = self.max_time / 1000 / 1000
        avg_time = self.total_time_ns / self.count / 1000 / 1000
        total = self.total_time_ns / 1000 / 1000
        return f"{min_time:15.2f}{max_time:15.2f}{avg_time:15.2f}{self.count:15}{total:15.2f}"


timeStat = TimeStat()


def get_free_ports(n: int, continue_prot: list):
    """Find `n` free TCP ports, skipping any port listed in `continue_prot`.

    The sockets are closed before returning, so the ports are only likely to
    still be free when the caller binds them.
    """
    sockets = []
    ports = []
    while len(ports) < n:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind(("", 0))
        port = s.getsockname()[1]
        if port in continue_prot:
            s.close()
            continue
        ports.append(port)
        sockets.append(s)
    for s in sockets:
        s.close()
    return ports

def get_current_device():
    """Return the current accelerator device string, e.g. "npu:0" or "cuda:0"."""
    if use_torch_npu:
        return f"npu:{torch.npu.current_device()}"
    else:
        return f"cuda:{torch.cuda.current_device()}"

def get_compute_capability(device:torch.device = None):
    """Return the CUDA compute capability major version (minimum across GPUs if no device is given); 0 on NPU."""
    if use_torch_npu:
        return 0
    if torch.cuda.is_available():
        if device is None:
            num_gpus = torch.cuda.device_count()
            min_compute_capability_major = 100
            for gpu_id in range(num_gpus):
                gpu_props = torch.cuda.get_device_properties(gpu_id)
                min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)
            return min_compute_capability_major
        else:
            return torch.cuda.get_device_properties(device).major

def set_module(model, submodule_key, module):
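    """Replace the submodule addressed by a dotted key (e.g. "model.layers.0.mlp") with `module`."""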
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        if hasattr(cur_mod, s):
            cur_mod = getattr(cur_mod, s)
        else: # indexed container such as nn.ModuleList / nn.Sequential
            cur_mod = cur_mod[int(s)]
    if hasattr(cur_mod, tokens[-1]):
        setattr(cur_mod, tokens[-1], module)
    else: # indexed container such as nn.ModuleList / nn.Sequential
        cur_mod[int(tokens[-1])] = module

def set_param(module: nn.Module, name: str, weights: torch.Tensor):
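    """Attach `weights` to `module` under `name` as a frozen (requires_grad=False) nn.Parameter."""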
    
    param=nn.parameter.Parameter(weights, requires_grad=False)
    if isinstance(module, nn.Linear) and len(weights.shape)==1:
        param.unsqueeze_(0)
    setattr(module, name, param)

def get_device(gguf_module_key:str, device_map:dict):
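    """Return the generation device configured for a module key, defaulting to "cuda"."""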
    if gguf_module_key in device_map:
        return device_map[gguf_module_key]["generate_device"]
    else:
        return "cuda"

def get_all_used_cuda_device(device_map:dict):
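    """Collect every accelerator device referenced in the device map, excluding "cpu"."""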
    all_device_list = set()
    for key in device_map:
        if "generate_device" in device_map[key]:
            all_device_list.add(device_map[key]["generate_device"])
        if "prefill_device" in device_map[key]:
            all_device_list.add(device_map[key]["prefill_device"])
    if "cpu" in all_device_list:
        all_device_list.remove("cpu")
    if use_torch_npu:
        all_device_list = set([device.replace('cuda', 'npu') for device in all_device_list])
    all_device_list = list(all_device_list)
    return all_device_list

def load_cur_state_dict_npu(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="npu"):
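    """NPU path: load this module's own parameters and persistent buffers from the loader."""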
    prefix = prefix.replace("orig_module.", "")
    persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
    local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
    local_state = {k: v for k, v in local_name_params if v is not None}
    for name, param in local_state.items():
        key = prefix + name
        translated_key = translate_name_to_gguf(key)
        # TODO: Merge all loader.
        # I know this is ugly but lets do it for now.
        if gguf_loader.safetensor_loader is not None:
            load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
            tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
        else:
            load_dequantized_tensor = gguf_loader.load_gguf_tensor
            tensor_file_map = gguf_loader.tensor_file_map
        
        if translated_key in tensor_file_map:
            target_dtype = torch.get_default_dtype()
            device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
            # Todo need fix
            device = "cpu" if "embd" in translated_key else get_current_device()
            print(f"loading layer {translated_key} to {device}")
            torch.cuda.empty_cache()
            weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
            set_param(module, name, weights)
            del weights
        else:
            #print(load_config.tensor_file_map.keys())
            raise Exception(f"can't find {translated_key} in GGUF file!")

def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
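    """Load this module's own parameters and persistent buffers from the GGUF/safetensor loader onto their mapped devices."""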
    if use_torch_npu:
        load_cur_state_dict_npu(module, gguf_loader, prefix, device)
        return

    prefix = prefix.replace("orig_module.", "")
    persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
    local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
    local_state = {k: v for k, v in local_name_params if v is not None}
    for name, param in local_state.items():
        key = prefix + name
        translated_key = key
        
        # TODO: Merge all loader.
        # I know this is ugly but lets do it for now.
        if isinstance(gguf_loader, SafeTensorLoader):
            load_dequantized_tensor = gguf_loader.load_dequantized_tensor
        else:
            load_dequantized_tensor = gguf_loader.load_gguf_tensor
            tensor_file_map = gguf_loader.tensor_file_map
        
        if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key:
            target_dtype = torch.get_default_dtype()
            device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
            print(f"loading {translated_key} to {device}")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif torch.xpu.is_available():
                torch.xpu.empty_cache()
            if "kv_b_proj" in translated_key and not gguf_loader.has_tensor(translated_key):
                attn_k_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_k_b"), device=device).to(dtype=target_dtype)
                attn_k_b = attn_k_b.transpose(1, 2).contiguous()
                attn_v_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_v_b"), device=device).to(dtype=target_dtype)
                kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)
                kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous()
                set_param(module, name, kv_b_proj)
                del attn_k_b
                del attn_v_b
            else:
                weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
                set_param(module, name, weights)
                del weights
        else:
            #print(load_config.tensor_file_map.keys())
            raise Exception(f"can't find {translated_key} in GGUF file!")

  
def sync_all_device(all_device_list):
    """Synchronize every device in the list (CUDA, XPU, or NPU)."""
    for device in all_device_list:
        if "cuda" in device.lower():
            torch.cuda.synchronize(device)
        elif "xpu" in device.lower():
            torch.xpu.synchronize(device)
        elif use_torch_npu:
            torch_npu.npu.synchronize(device)
        else:
            raise RuntimeError("The device {} is not available".format(device))

torch_device_mapping ={"cuda": "cuda:0", "xpu": "xpu:0"}

def xpu_fp16_model(config):
    # This function is to check if we run this model on XPU with FP16 dtype
    if not torch.xpu.is_available():
        return False
    if config.architectures[0] == "DeepseekV3ForCausalLM":
        return True
    if config.architectures[0] == "Qwen3MoeForCausalLM" and config.hidden_size == 4096:
        # Qwen3-30B seems to have precision issues with FP16,
        # so we only use FP16 for Qwen3-235B for now
        return True
    return False

def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
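    """Recursively load weights for `module`; injected modules handle their own loading via load()."""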
    #print(f"recursively loading weights {prefix}")
    if not isinstance(module, base_operator.BaseInjectedModule):
        load_cur_state_dict(module, gguf_loader, prefix, device=device)
        for name, child in module._modules.items():
            load_weights(child, gguf_loader, prefix+name+".", device=device)
    else:
        module.load()

def tf_logits_warper(generation_config):
        """
        This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
        used for multinomial sampling.
        """

        # instantiate warpers list
        warpers = LogitsProcessorList()

        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
        # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
        if generation_config.num_beams > 1:
            if isinstance(generation_config._eos_token_tensor, list):
                min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
            elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
                min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
            else:
                min_tokens_to_keep = 2
        else:
            min_tokens_to_keep = 1

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if generation_config.temperature is not None and generation_config.temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(generation_config.temperature))
        if generation_config.top_k is not None and generation_config.top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
        if generation_config.top_p is not None and generation_config.top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
        if generation_config.min_p is not None:
            # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
            warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
        if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
            warpers.append(
                TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
            )
        if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
            warpers.append(
                EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
            )
        if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
            warpers.append(
                # Note: no `device` is in scope here; rely on EtaLogitsWarper's default ("cpu").
                EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep)
            )
        # `LogitNormalization` should always be the last logit processor, when present
        if generation_config.renormalize_logits is True:
            warpers.append(LogitNormalization())
        return warpers

def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
                         mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
                         num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None,
                         static_cache = None, draft_model=None, draft_cache=None):
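    """Run chunked prefill over `inputs`, then decode up to `max_new_tokens` tokens,
    streaming generated text to stdout and returning the list of generated token ids.
    """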
    import os
    
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch._dynamo.config.suppress_errors = True
    batch_size, seq_length = inputs.shape
    device_map = model.gguf_loader.tensor_device_map
    if use_torch_npu:
        # Bind the module-level CUR_DEVICE (declared above) rather than shadowing it with a local.
        global CUR_DEVICE
        CUR_DEVICE = f"npu:{torch.npu.current_device()}"
        vocabulary_size = model.config.vocab_size
        topp = torch.tensor([[model.generation_config.top_p]], dtype=torch.float16).npu()
        topk = torch.tensor([[model.generation_config.top_k]], dtype=torch.int32).npu()
        temperature = torch.tensor([[model.generation_config.temperature]], dtype=torch.float16).npu()
        next_token_fake = torch.tensor([[1]], dtype=torch.int32).npu()
        next_token_probs = torch.tensor([[1.0]], dtype=torch.float16).npu()
        torch_device = torch.npu.current_device()
    else:
        torch_device = get_device('model.layers.0.self_attn', device_map)
        torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
    inputs = inputs.to(torch_device)
    all_cuda_device = get_all_used_cuda_device(device_map)

    tokens = []

    def decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
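        """Decode one token on the NPU, replaying the captured NPU graph when enabled."""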
        if cuda_graph_runner is None:
            use_cuda_graph = False
        
        inputs_embeds = model.model.embed_tokens(cur_token.to('cpu')).to(torch_device)
        if use_cuda_graph:
            if cuda_graph_runner.model_capture:
                cuda_graph_runner.capture(model, cur_token, position_ids, cache_position, past_key_values, CUR_DEVICE, return_dict=False, use_cache=True)
                cuda_graph_runner.model_capture = False

            ret = cuda_graph_runner(inputs_embeds, position_ids, cache_position)
            logits = ret[0]
            next_token = torch.argmax(logits, dim=-1)
        else:
            torch_npu.npu.set_device(torch_device)
            logits = model(inputs_embeds=inputs_embeds,
                       position_ids=position_ids,
                       cache_position=cache_position,
                       past_key_values=past_key_values,
                       return_dict=False, use_cache=True, is_prefill=False)[0]
        if past_key_values is not None:
            past_key_values.change_seq_length(1)

        if generation_config.do_sample:
            logits = logits / temperature
            torch.manual_seed(0)
            probs = logits.view(batch_size, vocabulary_size)
            sm = nn.Softmax(dim=-1)
            probs = sm(probs).half().npu()
            next_token = next_token_fake
            torch_npu._npu_topk_topp_sampling(probs, topk, topp, next_token, next_token_probs)
            next_token = next_token.squeeze(-1)
        else:
            next_token_scores = logits_warper(inputs, logits[:, -1, :])
            next_token = torch.argmax(next_token_scores, dim=-1)
        
        return next_token
            
    
    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
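        """Decode one token (CUDA graph replay or eager forward), then sample or take the argmax."""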
        if use_torch_npu:
            return decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph)
        if cuda_graph_runner is None:
            use_cuda_graph = False
        if use_cuda_graph:
            logits = cuda_graph_runner(cur_token, position_ids, cache_position)
        else:
            # custom_stream = torch.cuda.Stream()
            if torch.cuda.is_available():
                torch.cuda.set_device(torch_device)
            elif torch.xpu.is_available():
                torch.xpu.set_device(torch_device)
            elif use_torch_npu:
                torch_npu.npu.set_device(torch_device)
            else:
                raise RuntimeError(f"The device: {torch_device} is not available")
            inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
            # with torch.cuda.stream(custom_stream):
            logits=model(inputs_embeds=inputs_embeds,
                        position_ids=position_ids,
                        cache_position=cache_position,
                        past_key_values=past_key_values,
                        return_dict=False, use_cache=True)[0]
        if past_key_values is not None and isinstance(past_key_values, StaticCache):
            past_key_values.change_seq_length(1)
        sync_all_device(all_cuda_device)
        next_token_scores = logits_warper(inputs, logits[:, -1, :])
        if generation_config.do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_token = torch.argmax(next_token_scores, dim=-1)
        return next_token

    # TODO: use CUDA Graph for chunk prefill, may get small improvement
    def chunk_prefill(inputs, cache_position, past_key_values):
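        """Embed one chunk of the prompt and run a prefill forward pass, returning the last-position logits."""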
        if mode == "long_context":
            inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
        else:
            inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
            # inputs_embeds = torch_npu.npu_format_cast_(inputs_embeds, 29)
        if use_flashinfer_mla:
            MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
            MLAWrapperSingleton.need_plan_all()

        ret = model(
            inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True, is_prefill=True
        )
        logits = ret[0][:,-1,:].unsqueeze(0).clone().to(torch_device)

        return logits

    def decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof=None):
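        """NPU decode loop: optionally capture/replay the NPU graph and stream tokens until EOS or the token budget."""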
        global warm_uped
        global _USE_NPU_GRAPH
        if use_cuda_graph:
            from ktransformers.util.npu_graph_runner import get_or_create_runner
            npu_graph_runner = get_or_create_runner(CUR_DEVICE)
            npu_graph_runner.init(batch_size, seq_length)
            
            with torch_npu.npu.stream(npu_graph_runner.main_stream):
                gen_num_tokens = 1
                while gen_num_tokens < max_new_tokens:
                    start_time = timeStat.record_start_time()
                    if use_flashinfer_mla:
                        MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
                                                    num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
                                                    model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
                    if gen_num_tokens == 1:
                        warm_uped = True
                        _USE_NPU_GRAPH = True
                        #np_graph_runner.capture(model, draft_model, next_token, torch.tensor(draft_token), position_ids, cache_position, past_key_values, draft_cache, torch_device, return_dict=False, use_cache=True)
                        cuda_graph_runner = npu_graph_runner
                    next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph)
                    next_token = next_token.to(torch_device)
                    inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
                    generated_ids[:, cache_position] = next_token.int()
                    tokens.append(int(next_token))
                    
                    seq_length += 1

                    if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
                        print(stream.end(), end="", flush=True)
                        break
                    else:
                        if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
                            print(stream.put(next_token.item()), end="", flush=True)

                    cache_position += 1
                    past_key_values.position[0] += 1
                    position_ids = cache_position.unsqueeze(0)
                    gen_num_tokens += 1
                    
                    if prof is not None:
                        prof.step()

                npu_graph_runner.destroy()
                _USE_NPU_GRAPH = False
        else:
            gen_num_tokens = 1
            while gen_num_tokens < max_new_tokens:
                if use_flashinfer_mla:
                    MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
                                                num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
                                                model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
                next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph)
                next_token = next_token.to(torch_device)
                inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
                generated_ids[:, cache_position] = next_token.int()
                tokens.append(int(next_token))
                seq_length += 1

                if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
                    print(stream.end(), end="", flush=True)
                    break
                else:
                    if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
                        print(stream.put(next_token.item()), end="", flush=True)

                cache_position += 1
                past_key_values.position[0] += 1
                position_ids = cache_position.unsqueeze(0)
                gen_num_tokens += 1

                if prof is not None:
                    prof.step()
        
        if prof is not None:
            prof.stop()
    
    if torch.cuda.is_available():
        torch.cuda.set_device(torch_device)
    elif torch.xpu.is_available():
        torch.xpu.set_device(torch_device)
    elif use_torch_npu:
        torch_npu.npu.set_device(torch_device)
    else:
        raise RuntimeError(f"The device: {torch_device} is not available")

    with torch.no_grad():

        stream = TextStreamer(tokenizer)
        if torch.xpu.is_available():
            from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache
            if model.config.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
                past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
            else:
                past_key_values = DynamicNormalCache.from_legacy_cache(None)
        elif use_torch_npu and static_cache:
            assert isinstance(static_cache, StaticCache), '[ERROR] static_cache format not equal to StaticCache'
            past_key_values = static_cache
            if past_key_values.max_batch_size < batch_size or past_key_values.max_cache_len < seq_length + max_new_tokens:
                print('[WARN] current StaticCache is too small; creating a new StaticCache...')
                past_key_values = StaticCache(
                    config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=device_map, dtype=model.dtype
                )
            else:
                past_key_values.reset()
        elif mode != 'long_context':
            past_key_values = StaticCache(
                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
            )
        else:
            past_key_values = None

        generation_config, model_kwargs = model._prepare_generation_config(
            None, do_sample=False
            # change this to modify generate config
            #top_k=5, top_p=0.85, temperature=0.1
        )
        
        logits_warper = tf_logits_warper(generation_config)

        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
        if use_torch_npu:
            past_key_values.position[0] = seq_length + 1
        generated_ids = torch.zeros(
            batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
        )
        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
        start_time = time.time()
        logits = None

        def prefill_wrapper(prof=None):
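            """Prefill the prompt in chunks of `chunk_size`, keeping the logits from the final chunk."""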
            nonlocal logits
            chunk_start = 0
            while chunk_start < seq_length:
                chunk_end = min(chunk_start + chunk_size, seq_length)
                if past_key_values is not None:
                    past_key_values.cur_idx = cache_position[chunk_start:chunk_end]
                logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
                chunk_start += chunk_size
                if prof is not None:
                    prof.step()
            if prof is not None:
                prof.stop()
            if logits is None:
                raise ValueError('logits cannot be None')

        if use_torch_npu:
            global WARM_UP_SKIP_CNT
            prof_prefill = os.environ.get("PROF_PREFILL", "0")
            if prof_prefill == "1" and WARM_UP_SKIP_CNT[0] <= 0:
                experimental_config = torch_npu.profiler._ExperimentalConfig(
                    aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
                    profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
                )
                with torch_npu.profiler.profile(
                        activities=[
                            torch_npu.profiler.ProfilerActivity.CPU,
                            torch_npu.profiler.ProfilerActivity.NPU
                        ],
                        schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),
                        on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof"),
                        record_shapes=True,
                        profile_memory=True,
                        with_stack=False,
                        with_flops=False,
                        with_modules=False,
                        experimental_config=experimental_config) as prof:
                    prefill_wrapper(prof)
            else:
                prefill_wrapper()
            WARM_UP_SKIP_CNT[0] -= 1
        else:
            prefill_wrapper()

        next_token_scores = logits_warper(inputs, logits[:, -1, :])
        if generation_config.do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_token = torch.argmax(next_token_scores, dim=-1)

        first_token_time = time.time() - start_time

        # print(f"------------------------------------- prefill next_token {next_token}  draft_token {draft_token} ")
        if use_flashinfer_mla:
            MLAWrapperSingleton.reset_buffer()

        prefill_count = seq_length
        prefill_time = first_token_time
        if not use_torch_npu or torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
            if force_think:
                print("<think>")
            print(stream.put(next_token.item()), end="", flush=True)

        generated_ids[:, seq_length] = next_token
        tokens.append(int(next_token))
        inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
        cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
        position_ids = cache_position.unsqueeze(0)
        seq_length += 1
        
        cuda_graph_runner = None
        
        start_time = time.time()

        if not use_torch_npu:
            for i in range(1, max_new_tokens):
                if use_flashinfer_mla:
                    MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
                                                num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
                                                model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
                global warm_uped
                if use_cuda_graph and ((warm_uped and i == 1) or (not warm_uped and i == 2)):
                    warm_uped = True
                    cuda_graph_runner = CUDAGraphRunner()
                    cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
                next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
                inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
                generated_ids[:, cache_position] = next_token.int()
                tokens.append(int(next_token))
                seq_length += 1
                
                if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
                    print(stream.end(), end="", flush=True)
                    break
                else:
                    print(stream.put(next_token.item()), end="", flush=True)
                cache_position += 1
                position_ids = cache_position.unsqueeze(0)
        else:
            prof_decode = os.environ.get("PROF_DECODE", "0")
            prof_ranks = os.environ.get("PROF_RANK", "0")
            prof_ranks = [int(r.strip()) for r in prof_ranks.split(",")]
            if prof_decode == "1" and torch.distributed.get_rank() in prof_ranks and WARM_UP_SKIP_CNT[1] <= 0:
                experimental_config = torch_npu.profiler._ExperimentalConfig(
                    aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
                    profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
                )
                with torch_npu.profiler.profile(
                        activities=[
                            torch_npu.profiler.ProfilerActivity.CPU,
                            torch_npu.profiler.ProfilerActivity.NPU
                        ],
                        schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=_MAX_DECODE_PROFILE, repeat=1, skip_first=0),
                        on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./decode_prof"),
                        record_shapes=True,
                        profile_memory=True,
                        with_stack=False,
                        with_flops=False,
                        with_modules=False,
                        experimental_config=experimental_config) as prof:
                    decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof)
            else:
                decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length)
            WARM_UP_SKIP_CNT[1] -= 1 

    total_time = time.time() - start_time
    tokens_generated = len(tokens)
    tokens_per_second = tokens_generated / total_time

    if not use_torch_npu:
        print("")

        print(f"prompt eval count:    {prefill_count} token(s)")
        print(f"prompt eval duration: {prefill_time}s")
        print(f"prompt eval rate:     {prefill_count/prefill_time} tokens/s")
        print(f"eval count:           {tokens_generated} token(s)")
        print(f"eval duration:        {total_time}s")
        print(f"eval rate:            {tokens_per_second} tokens/s")
    else:
        tp_size = get_tensor_parallel_size()
        if torch.distributed.get_rank() % tp_size == 0:
            rank = f"[rank:{torch.distributed.get_rank()}]"
            msg = f"\n{rank} Eval Time\n"
            msg += rank + f"prompt eval count:    {prefill_count} token(s)\n"
            msg += rank + f"prompt eval duration: {prefill_time:.9f}s\n"
            msg += rank + f"prompt eval rate:     {prefill_count/prefill_time:.9f} tokens/s\n"
            msg += rank + f"eval count:           {tokens_generated} token(s)\n"
            msg += rank + f"eval duration:        {total_time:.9f}s\n"
            msg += rank + f"eval rate:            {tokens_per_second:.9f} tokens/s\n"
            print(msg)

    return tokens

class InferenceState(enum.Enum):
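    """Inference lifecycle / module placement states: unloaded, prefill, generate, restore."""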
    UNLOAD = 0
    PREFILL = 1
    GENERATE = 2
    RESTORE = 3
